# -*- coding:utf-8 -*-
########################################################################################################
# The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
import math
import time
import types
import copy
import torch
from torch.nn import functional as F
from src.utils import TOKENIZER, Dataset
from src.model_run import RWKV_RNN
# Allow TF32 in both cuBLAS matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Turn on the cuDNN benchmark (auto-tuner) flag.
torch.backends.cudnn.benchmark = True
# Compact numpy printing: 4 decimals, no scientific notation, wide lines.
np.set_printoptions(
    linewidth=200,
    suppress=True,
    precision=4,
)

### Step 1: set model ##################################################################################

# Architecture settings handed to RWKV_RNN when the model is built below.
# ctx_len is additionally used by this script as the evaluation window length
# and as the maximum number of trailing tokens fed to the model while sampling.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with - confirm against train.py.
ctx_len = 1024
n_layer = 6
n_embd = 512
model_type = 'RWKV'           # 'RWKV' or 'RWKV-ffnPre'

# your trained model
MODEL_NAME = 'trained-31'     # passed to RWKV_RNN as the model name
WORD_NAME = 'vocab'           # the .json vocab (generated by train.py); passed to TOKENIZER

# ########## Uncomment these to test my 27M params enwik8 model ##########
# MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
# WORD_NAME = 'enwik8-vocab'
# EVAL_DATA = 'enwik8'  # uncomment this for EVAL MODE (no text generation)
# ########################################################################

# --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
# --> all unknown tokens in your context will be denoted by it <--
UNKNOWN_CHAR = ' '   # here we just set it to [space] for simplicity

RUN_DEVICE = 'cpu'   # 'cpu' (already very fast) or 'cuda'
# When True, only one trial of one token is generated and the raw model output
# is printed alongside its input.
DEBUG_DEBUG = False  # True False - show softmax output

### Step 2: set context ################################################################################

# The prompt; it is passed through tokenizer.refine_context() before use.
context = "\nIn the"       # ==> this is your prompt

NUM_TRIALS = 999           # number of independent generations from the same prompt
LENGTH_PER_TRIAL = 500     # number of tokens sampled in each generation

# Sampling settings forwarded to tokenizer.sample_logits().
TEMPERATURE = 1.0
top_p = 0.7                # forwarded as top_p_usual
# NOTE(review): presumably the top-p threshold applied right after a newline -
# confirm in TOKENIZER.sample_logits.
top_p_newline = 0.9

########################################################################################################

# Instantiate the RNN-mode model and the tokenizer from the Step 1 settings.
print(f'Loading {MODEL_NAME}...')
model = RWKV_RNN(
    MODEL_NAME,
    RUN_DEVICE,
    model_type,
    n_layer,
    n_embd,
    ctx_len,
)
tokenizer = TOKENIZER(
    WORD_NAME,
    UNKNOWN_CHAR=UNKNOWN_CHAR,
)

########################################################################################################

# EVAL MODE: runs only when EVAL_DATA has been defined (see Step 1). It measures
# the average per-position negative log-likelihood over randomly chosen windows
# of ctx_len + 1 characters, then terminates the script without generating text.
if 'EVAL_DATA' in globals():
    print('Evaluating on ' + EVAL_DATA + ' ...')

    # Use a context manager so the file handle is closed deterministically.
    with open(EVAL_DATA, "r", encoding='utf-8') as eval_file:
        data = eval_file.read()

    # Each sample needs ctx_len inputs plus one final target character, so the
    # valid start offsets are 0 .. len(data) - ctx_len - 1 (inclusive).
    n_windows = len(data) - ctx_len
    if n_windows < 1:
        raise ValueError(
            f'{EVAL_DATA} has only {len(data)} characters; '
            f'at least ctx_len + 1 = {ctx_len + 1} are required for evaluation')

    # loss_table[k] accumulates the loss of predicting token k + 1 of a window
    # from its first k + 1 tokens, summed over all samples seen so far.
    loss_table = np.zeros(ctx_len)

    N_SAMPLE = 1000

    for iii in range(N_SAMPLE):
        # np.random.randint excludes the upper bound, so passing n_windows
        # makes every valid start offset (including the last one) reachable.
        pos = np.random.randint(0, n_windows)
        sample = data[pos:pos+ctx_len+1]
        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in sample]

        model.clear()
        for i in range(1, ctx_len+1):
            out = model.run(ctx[:i])
            # log_softmax stays finite where log(softmax(x)) would hit log(0)
            # once the target probability underflows.
            log_prob = F.log_softmax(torch.tensor(out), dim=-1)
            loss_table[i-1] -= log_prob[ctx[i]].item()

        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
              np.mean(loss_table) / (iii+1))

    # SystemExit is always available, unlike the site-provided exit() helper.
    raise SystemExit(0)

########################################################################################################

# Text generation: normalise the prompt, then sample NUM_TRIALS continuations
# of LENGTH_PER_TRIAL tokens each (a single token of a single trial in debug mode).
context = tokenizer.refine_context(context)
print('\nYour prompt has ' + str(len(context)) + ' tokens.')
print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')

n_trials = 1 if DEBUG_DEBUG else NUM_TRIALS
n_new_tokens = 1 if DEBUG_DEBUG else LENGTH_PER_TRIAL

for TRIAL in range(n_trials):
    t_begin = time.time_ns()

    src_len = len(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
    print(('-' * 30) + context, end='')

    model.clear()
    if TRIAL > 0:
        # Later trials reuse what the first trial stored in init_state
        # instead of running the prompt through the model again.
        model.load(init_state)
    else:
        # First trial: feed the prompt one growing prefix at a time and keep
        # the output produced for the complete prompt.
        init_state = types.SimpleNamespace()
        for prefix_len in range(1, src_len + 1):
            prompt_out = model.run(ctx[:prefix_len])
            if prefix_len == src_len:
                init_state.out = prompt_out
        model.save(init_state)

    for step in range(n_new_tokens):
        # Only the most recent ctx_len tokens are handed to the model/sampler.
        x = ctx[-ctx_len:]

        # The first new token is sampled from the stored prompt output; a copy
        # is used so init_state.out itself is never handed to the sampler.
        out = copy.deepcopy(init_state.out) if step == 0 else model.run(x)

        if DEBUG_DEBUG:
            print('model', np.array(x), '==>', np.array(
                out), np.max(out), np.min(out))

        token = tokenizer.sample_logits(
            out, x, ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        ).item()
        print(tokenizer.itos[int(token)], end='', flush=True)
        ctx.append(token)

    elapsed_ns = time.time_ns() - t_begin
    print("\n----------", round(elapsed_ns / (10 ** 9), 2), end='s ')
